UVA1407 Caves

一道标准的树形dp。

dp[u][i][0/1]dp[u][i][0/1] 表示以 uu 为根的子树 , 探索 ii 个洞穴,是否回到 uu

显然有:

{dp[u][j][0]=min(dp[u][k][1]+dp[v][jk][0]+w)先走到 v 点,再在 v 的子树绕一圈后dp[u][j][0]=min(dp[u][k][0]+dp[v][jk][1]+2w)先走到 v 点,在 v 的子树绕一圈后,回到 u 点,再走向另一个儿子dp[u][j][1]=min(dp[u][k][1]+dp[v][jk][1]+2w)先走到 v 点,在 v 的子树绕一圈后,再从 v 点走回来\begin{cases} dp[ u ][ j ][ 0 ] = min( dp[ u ][ k ][ 1 ] + dp[ v ][ j - k ][ 0 ] + w )&\text{先走到 $v$ 点,再在 $v$ 的子树绕一圈后}\\ dp[ u ][ j ][ 0 ] = min( dp[ u ][ k ][ 0 ] + dp[ v ][ j - k ][ 1 ] + 2 * w )&\text{先走到 $v$ 点,在 $v$ 的子树绕一圈后,回到 $u$ 点,再走向另一个儿子}\\ dp[ u ][ j ][ 1 ] = min( dp[ u ][ k ][ 1 ] + dp[ v ][ j - k ][ 1 ] + 2 * w )&\text{先走到 $v$ 点,在 $v$ 的子树绕一圈后,再从 $v$ 点走回来} \end{cases}

显然 dpdp 数组是单增的,可以二分寻找答案,复杂度Θ(n3+qlog2n)\Theta(n^3+q\log_2n)

#include <cstdio>
#include <vector>
#include <cstring>
using namespace std;

const int MAXN = 500;
int n , ca , u , v , w , q , x , In[ MAXN + 5 ] , Size[ MAXN + 5 ];
int dp[ MAXN + 5 ][ MAXN + 5 ][ 2 ];
struct node{
	int v , w;
};
vector< node > Graph[ MAXN + 5 ];

void dfs( int u ) {
	Size[ u ] = 1;
	dp[ u ][ 0 ][ 0 ] = dp[ u ][ 1 ][ 0 ] = 0;
	dp[ u ][ 0 ][ 1 ] = dp[ u ][ 1 ][ 1 ] = 0;
	
	for( int i = 0 ; i < Graph[ u ].size( ) ; i ++ ) {
		int v = Graph[ u ][ i ].v , w = Graph[ u ][ i ].w;
		dfs( v ); Size[ u ] += Size[ v ];
		
		for( int j = Size[ u ] ; j >= 1 ; j -- )
			for( int k = j ; k >= j - Size[ v ] ; k -- ) {
				dp[ u ][ j ][ 0 ] = min( dp[ u ][ j ][ 0 ] , dp[ u ][ k ][ 1 ] + dp[ v ][ j - k ][ 0 ] + w );
				dp[ u ][ j ][ 0 ] = min( dp[ u ][ j ][ 0 ] , dp[ u ][ k ][ 0 ] + dp[ v ][ j - k ][ 1 ] + 2 * w );
				dp[ u ][ j ][ 1 ] = min( dp[ u ][ j ][ 1 ] , dp[ u ][ k ][ 1 ] + dp[ v ][ j - k ][ 1 ] + 2 * w );
			}
	}
}

int solve( int que ) {
	int l = 1 , r = n , Ans;
	while( l <= r ) {
		int Mid = ( l + r ) / 2;
		if( dp[ 1 ][ Mid ][ 0 ] <= que || dp[ 1 ][ Mid ][ 1 ] <= que )
			l = Mid + 1 , Ans = Mid;
		else
			r = Mid - 1;
	}
	return Ans;
}

int main( ) {
	while( scanf("%d",&n) && n ) {
		printf("Case %d:\n",++ca);
		memset( dp , 0x3f , sizeof( dp ) );
		for( int i = 1 ; i <= n ; i ++ ) Graph[ i ].clear();
		
		for( int i = 1 ; i <= n - 1 ; i ++ ) {
			scanf("%d %d %d",&u,&v,&w); u ++ , v ++;
			Graph[ v ].push_back( { u , w } );	
		}
		dfs( 1 );		
		scanf("%d",&q);
		while( q -- ) {
			scanf("%d",&x);
			printf("%d\n",solve( x ));
		}
	}
	return 0;
}